{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('1.1.0a4', '0.1.3')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "import cvxpylayers\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "from algorithms import fit\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from latexify import latexify\n",
    "cp.__version__, cvxpylayers.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define convex optimization model: spend a budget B across m assets so as\n",
    "# to maximize the total exponential utility of the units purchased.\n",
    "m = 10\n",
    "\n",
    "# Decision variables: y = amount spent per asset, units = quantity bought,\n",
    "# t = epigraph variable bounding each asset's utility.\n",
    "y = cp.Variable(m)\n",
    "units = cp.Variable(m)\n",
    "t = cp.Variable(m)\n",
    "\n",
    "# Parameters supplied to the layer at call time, in the order (B, inverse_p, alpha).\n",
    "alpha = cp.Parameter(m, nonneg=True)\n",
    "inverse_p = cp.Parameter(m, nonneg=True)\n",
    "B = cp.Parameter(nonneg=True)\n",
    "\n",
    "objective = cp.sum(t)\n",
    "constraints = [\n",
    "    cp.sum(y) == B,  # spend exactly the budget\n",
    "    y >= 0,  # spending is nonnegative\n",
    "    # alpha * t <= -exp(-alpha * units), i.e. t_i <= -exp(-alpha_i * units_i) / alpha_i\n",
    "    # for alpha_i > 0 (if some alpha_i == 0 this reads -1 >= 0: infeasible)\n",
    "    -cp.exp(-cp.multiply(alpha, units)) >= cp.multiply(alpha, t),\n",
    "    units == cp.multiply(y, inverse_p),  # units = y * (1 / p)\n",
    "]\n",
    "prob = cp.Problem(cp.Maximize(objective), constraints)\n",
    "layer = CvxpyLayer(prob, parameters=[B, inverse_p, alpha], variables=[y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.9211e-01,  4.5316e-01,  9.4326e-01,  1.9860e-01,  7.3760e-01,\n",
       "          3.1330e-01,  1.3644e+00,  4.8556e-01,  5.5352e-01,  5.6666e-01],\n",
       "        [ 1.3083e-01, -1.1127e-08,  3.7344e-01,  1.8442e-08,  2.2681e-01,\n",
       "          1.1067e-08,  1.0708e-08,  1.7964e-08, -1.8988e-08,  3.0709e-01],\n",
       "        [ 3.2713e-01,  6.7104e-01,  6.5384e-01,  5.6062e-01,  8.3935e-01,\n",
       "          3.0430e-01,  2.0327e+00,  9.8297e-01,  4.8012e-01,  2.0594e+00]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get data\n",
    "def get_data(N, m, alpha):\n",
    "    \"\"\"Sample N synthetic (features, allocation) pairs from the model.\n",
    "\n",
    "    Prices P ~ U[0, 1)^m and budgets B ~ U[0, 10). The layer's optimal\n",
    "    allocations are perturbed by multiplicative noise in [0.5, 1.5) and then\n",
    "    rescaled so every row again sums to its budget.\n",
    "\n",
    "    `m` must equal the size the global `layer` was built with.\n",
    "\n",
    "    Returns (X, Y): X is (N, m + 1) with the budget in column 0 and the\n",
    "    prices in columns 1..m; Y is the (N, m) noisy allocation matrix.\n",
    "    \"\"\"\n",
    "    P = torch.rand(N, m)\n",
    "    B = torch.rand(N) * 10\n",
    "    # The solver returns tiny negative values (~1e-8) for zero allocations;\n",
    "    # clamp them, since the model constrains y >= 0.\n",
    "    Y = layer(B, 1 / P, alpha)[0].clamp(min=0)\n",
    "    Y = Y * (torch.rand_like(Y) + .5)\n",
    "    Y = Y * B[:,None] / Y.sum(1)[:,None]\n",
    "    return torch.cat([B.unsqueeze(1), P], dim=1), Y\n",
    "\n",
    "torch.manual_seed(0)\n",
    "alpha_true = torch.rand(m)\n",
    "X, Y = get_data(100, m, alpha_true)\n",
    "Xval, Yval = get_data(50, m, alpha_true)\n",
    "Y[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-squared error averaged over every entry of the allocation matrices\n",
    "loss_fn = torch.nn.MSELoss(reduction=\"mean\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0321) tensor(0.0282)\n",
      "tensor(0.8471, grad_fn=<MeanBackward0>) tensor(0.7306, grad_fn=<MeanBackward0>)\n",
      "001 | 0.73061\n",
      "batch 001 / 025 | 0.95666\n",
      "batch 002 / 025 | 1.00477\n",
      "batch 003 / 025 | 0.94091\n",
      "batch 004 / 025 | 0.84717\n",
      "batch 005 / 025 | 0.89175\n",
      "batch 006 / 025 | 0.85781\n",
      "batch 007 / 025 | 0.78630\n",
      "batch 008 / 025 | 0.71243\n",
      "batch 009 / 025 | 0.75978\n",
      "batch 010 / 025 | 0.71935\n",
      "batch 011 / 025 | 0.68150\n",
      "batch 012 / 025 | 0.64122\n",
      "batch 013 / 025 | 0.62107\n",
      "batch 014 / 025 | 0.59418\n",
      "batch 015 / 025 | 0.57193\n",
      "batch 016 / 025 | 0.56006\n",
      "batch 017 / 025 | 0.54634\n",
      "batch 018 / 025 | 0.52650\n",
      "batch 019 / 025 | 0.52287\n",
      "batch 020 / 025 | 0.50262\n",
      "batch 021 / 025 | 0.49274\n",
      "batch 022 / 025 | 0.48326\n",
      "batch 023 / 025 | 0.46799\n",
      "batch 024 / 025 | 0.45285\n",
      "batch 025 / 025 | 0.43988\n",
      "002 | 0.16683\n",
      "batch 001 / 025 | 0.28635\n",
      "batch 002 / 025 | 0.22066\n",
      "batch 003 / 025 | 0.24241\n",
      "batch 004 / 025 | 0.24119\n",
      "batch 005 / 025 | 0.21740\n",
      "batch 006 / 025 | 0.20846\n",
      "batch 007 / 025 | 0.19698\n",
      "batch 008 / 025 | 0.18614\n",
      "batch 009 / 025 | 0.17273\n",
      "batch 010 / 025 | 0.17029\n",
      "batch 011 / 025 | 0.15601\n",
      "batch 012 / 025 | 0.15188\n",
      "batch 013 / 025 | 0.15734\n",
      "batch 014 / 025 | 0.15144\n",
      "batch 015 / 025 | 0.15093\n",
      "batch 016 / 025 | 0.15447\n",
      "batch 017 / 025 | 0.15352\n",
      "batch 018 / 025 | 0.15201\n",
      "batch 019 / 025 | 0.14854\n",
      "batch 020 / 025 | 0.14326\n",
      "batch 021 / 025 | 0.14464\n",
      "batch 022 / 025 | 0.14080\n",
      "batch 023 / 025 | 0.13998\n",
      "batch 024 / 025 | 0.13892\n",
      "batch 025 / 025 | 0.13704\n",
      "003 | 0.09026\n",
      "batch 001 / 025 | 0.03625\n",
      "batch 002 / 025 | 0.07777\n",
      "batch 003 / 025 | 0.09149\n",
      "batch 004 / 025 | 0.10414\n",
      "batch 005 / 025 | 0.09634\n",
      "batch 006 / 025 | 0.09303\n",
      "batch 007 / 025 | 0.10193\n",
      "batch 008 / 025 | 0.10383\n",
      "batch 009 / 025 | 0.09616\n",
      "batch 010 / 025 | 0.09477\n",
      "batch 011 / 025 | 0.09227\n",
      "batch 012 / 025 | 0.09679\n",
      "batch 013 / 025 | 0.09136\n",
      "batch 014 / 025 | 0.09263\n",
      "batch 015 / 025 | 0.09307\n",
      "batch 016 / 025 | 0.09168\n",
      "batch 017 / 025 | 0.08879\n",
      "batch 018 / 025 | 0.09044\n",
      "batch 019 / 025 | 0.08978\n",
      "batch 020 / 025 | 0.09221\n",
      "batch 021 / 025 | 0.09078\n",
      "batch 022 / 025 | 0.08785\n",
      "batch 023 / 025 | 0.08942\n",
      "batch 024 / 025 | 0.08962\n",
      "batch 025 / 025 | 0.09214\n",
      "004 | 0.07297\n",
      "batch 001 / 025 | 0.07017\n",
      "batch 002 / 025 | 0.10268\n",
      "batch 003 / 025 | 0.10007\n",
      "batch 004 / 025 | 0.09278\n",
      "batch 005 / 025 | 0.08467\n",
      "batch 006 / 025 | 0.07947\n",
      "batch 007 / 025 | 0.07951\n",
      "batch 008 / 025 | 0.08184\n",
      "batch 009 / 025 | 0.07663\n",
      "batch 010 / 025 | 0.07075\n",
      "batch 011 / 025 | 0.07605\n",
      "batch 012 / 025 | 0.07518\n",
      "batch 013 / 025 | 0.07536\n",
      "batch 014 / 025 | 0.07888\n",
      "batch 015 / 025 | 0.07791\n",
      "batch 016 / 025 | 0.07644\n",
      "batch 017 / 025 | 0.07730\n",
      "batch 018 / 025 | 0.07525\n",
      "batch 019 / 025 | 0.08048\n",
      "batch 020 / 025 | 0.07777\n",
      "batch 021 / 025 | 0.08034\n",
      "batch 022 / 025 | 0.07846\n",
      "batch 023 / 025 | 0.07921\n",
      "batch 024 / 025 | 0.07841\n",
      "batch 025 / 025 | 0.07801\n",
      "005 | 0.06297\n",
      "batch 001 / 025 | 0.06807\n",
      "batch 002 / 025 | 0.07910\n",
      "batch 003 / 025 | 0.09111\n",
      "batch 004 / 025 | 0.08107\n",
      "batch 005 / 025 | 0.07136\n",
      "batch 006 / 025 | 0.06902\n",
      "batch 007 / 025 | 0.06606\n",
      "batch 008 / 025 | 0.06582\n",
      "batch 009 / 025 | 0.06273\n",
      "batch 010 / 025 | 0.06248\n",
      "batch 011 / 025 | 0.06317\n",
      "batch 012 / 025 | 0.06053\n",
      "batch 013 / 025 | 0.05991\n",
      "batch 014 / 025 | 0.06056\n",
      "batch 015 / 025 | 0.06498\n",
      "batch 016 / 025 | 0.06699\n",
      "batch 017 / 025 | 0.06871\n",
      "batch 018 / 025 | 0.06929\n",
      "batch 019 / 025 | 0.06978\n",
      "batch 020 / 025 | 0.06948\n",
      "batch 021 / 025 | 0.06884\n",
      "batch 022 / 025 | 0.06657\n",
      "batch 023 / 025 | 0.06554\n",
      "batch 024 / 025 | 0.06802\n",
      "batch 025 / 025 | 0.06975\n",
      "006 | 0.05528\n",
      "batch 001 / 025 | 0.06987\n",
      "batch 002 / 025 | 0.06583\n",
      "batch 003 / 025 | 0.06356\n",
      "batch 004 / 025 | 0.06702\n",
      "batch 005 / 025 | 0.06247\n",
      "batch 006 / 025 | 0.06740\n",
      "batch 007 / 025 | 0.07504\n",
      "batch 008 / 025 | 0.07623\n",
      "batch 009 / 025 | 0.07920\n",
      "batch 010 / 025 | 0.07598\n",
      "batch 011 / 025 | 0.07749\n",
      "batch 012 / 025 | 0.07553\n",
      "batch 013 / 025 | 0.07321\n",
      "batch 014 / 025 | 0.07264\n",
      "batch 015 / 025 | 0.07471\n",
      "batch 016 / 025 | 0.07691\n",
      "batch 017 / 025 | 0.07355\n",
      "batch 018 / 025 | 0.07382\n",
      "batch 019 / 025 | 0.07497\n",
      "batch 020 / 025 | 0.07185\n",
      "batch 021 / 025 | 0.07025\n",
      "batch 022 / 025 | 0.06842\n",
      "batch 023 / 025 | 0.06695\n",
      "batch 024 / 025 | 0.06526\n",
      "batch 025 / 025 | 0.06388\n",
      "007 | 0.05114\n",
      "batch 001 / 025 | 0.05940\n",
      "batch 002 / 025 | 0.05873\n",
      "batch 003 / 025 | 0.05460\n",
      "batch 004 / 025 | 0.06670\n",
      "batch 005 / 025 | 0.06367\n",
      "batch 006 / 025 | 0.06688\n",
      "batch 007 / 025 | 0.06048\n",
      "batch 008 / 025 | 0.05738\n",
      "batch 009 / 025 | 0.05766\n",
      "batch 010 / 025 | 0.05435\n",
      "batch 011 / 025 | 0.05218\n",
      "batch 012 / 025 | 0.05446\n",
      "batch 013 / 025 | 0.05260\n",
      "batch 014 / 025 | 0.05569\n",
      "batch 015 / 025 | 0.05592\n",
      "batch 016 / 025 | 0.05327\n",
      "batch 017 / 025 | 0.05366\n",
      "batch 018 / 025 | 0.05448\n",
      "batch 019 / 025 | 0.05554\n",
      "batch 020 / 025 | 0.05465\n",
      "batch 021 / 025 | 0.05554\n",
      "batch 022 / 025 | 0.05699\n",
      "batch 023 / 025 | 0.05804\n",
      "batch 024 / 025 | 0.05673\n",
      "batch 025 / 025 | 0.05698\n",
      "008 | 0.04549\n",
      "batch 001 / 025 | 0.01897\n",
      "batch 002 / 025 | 0.04835\n",
      "batch 003 / 025 | 0.04405\n",
      "batch 004 / 025 | 0.04311\n",
      "batch 005 / 025 | 0.05012\n",
      "batch 006 / 025 | 0.04597\n",
      "batch 007 / 025 | 0.04518\n",
      "batch 008 / 025 | 0.04353\n",
      "batch 009 / 025 | 0.04550\n",
      "batch 010 / 025 | 0.04657\n",
      "batch 011 / 025 | 0.04906\n",
      "batch 012 / 025 | 0.04847\n",
      "batch 013 / 025 | 0.04739\n",
      "batch 014 / 025 | 0.04811\n",
      "batch 015 / 025 | 0.04737\n",
      "batch 016 / 025 | 0.04985\n",
      "batch 017 / 025 | 0.04912\n",
      "batch 018 / 025 | 0.05182\n",
      "batch 019 / 025 | 0.05185\n",
      "batch 020 / 025 | 0.05113\n",
      "batch 021 / 025 | 0.05165\n",
      "batch 022 / 025 | 0.05323\n",
      "batch 023 / 025 | 0.05381\n",
      "batch 024 / 025 | 0.05447\n",
      "batch 025 / 025 | 0.05490\n",
      "009 | 0.04226\n",
      "batch 001 / 025 | 0.05405\n",
      "batch 002 / 025 | 0.05867\n",
      "batch 003 / 025 | 0.05000\n",
      "batch 004 / 025 | 0.04413\n",
      "batch 005 / 025 | 0.04549\n",
      "batch 006 / 025 | 0.05164\n",
      "batch 007 / 025 | 0.05519\n",
      "batch 008 / 025 | 0.05086\n",
      "batch 009 / 025 | 0.05776\n",
      "batch 010 / 025 | 0.05395\n",
      "batch 011 / 025 | 0.05519\n",
      "batch 012 / 025 | 0.05711\n",
      "batch 013 / 025 | 0.05819\n",
      "batch 014 / 025 | 0.05597\n",
      "batch 015 / 025 | 0.05461\n",
      "batch 016 / 025 | 0.05343\n",
      "batch 017 / 025 | 0.05332\n",
      "batch 018 / 025 | 0.05241\n",
      "batch 019 / 025 | 0.05222\n",
      "batch 020 / 025 | 0.05140\n",
      "batch 021 / 025 | 0.05180\n",
      "batch 022 / 025 | 0.05096\n",
      "batch 023 / 025 | 0.05138\n",
      "batch 024 / 025 | 0.04988\n",
      "batch 025 / 025 | 0.05016\n",
      "010 | 0.04157\n",
      "batch 001 / 025 | 0.06785\n",
      "batch 002 / 025 | 0.04026\n",
      "batch 003 / 025 | 0.05005\n",
      "batch 004 / 025 | 0.04866\n",
      "batch 005 / 025 | 0.04569\n",
      "batch 006 / 025 | 0.04358\n",
      "batch 007 / 025 | 0.04160\n",
      "batch 008 / 025 | 0.04378\n",
      "batch 009 / 025 | 0.04485\n",
      "batch 010 / 025 | 0.04618\n",
      "batch 011 / 025 | 0.04463\n",
      "batch 012 / 025 | 0.04227\n",
      "batch 013 / 025 | 0.04406\n",
      "batch 014 / 025 | 0.04456\n",
      "batch 015 / 025 | 0.04840\n",
      "batch 016 / 025 | 0.04842\n",
      "batch 017 / 025 | 0.04937\n",
      "batch 018 / 025 | 0.04869\n",
      "batch 019 / 025 | 0.04873\n",
      "batch 020 / 025 | 0.04878\n",
      "batch 021 / 025 | 0.04983\n",
      "batch 022 / 025 | 0.05007\n",
      "batch 023 / 025 | 0.04886\n",
      "batch 024 / 025 | 0.04894\n",
      "batch 025 / 025 | 0.04870\n",
      "011 | 0.03839\n",
      "batch 001 / 025 | 0.05088\n",
      "batch 002 / 025 | 0.06868\n",
      "batch 003 / 025 | 0.07717\n",
      "batch 004 / 025 | 0.05892\n",
      "batch 005 / 025 | 0.05794\n",
      "batch 006 / 025 | 0.05520\n",
      "batch 007 / 025 | 0.05639\n",
      "batch 008 / 025 | 0.05790\n",
      "batch 009 / 025 | 0.05572\n",
      "batch 010 / 025 | 0.05257\n",
      "batch 011 / 025 | 0.05002\n",
      "batch 012 / 025 | 0.04820\n",
      "batch 013 / 025 | 0.04771\n",
      "batch 014 / 025 | 0.04784\n",
      "batch 015 / 025 | 0.04624\n",
      "batch 016 / 025 | 0.04446\n",
      "batch 017 / 025 | 0.04346\n",
      "batch 018 / 025 | 0.04459\n",
      "batch 019 / 025 | 0.04522\n",
      "batch 020 / 025 | 0.04466\n",
      "batch 021 / 025 | 0.04465\n",
      "batch 022 / 025 | 0.04604\n",
      "batch 023 / 025 | 0.04501\n",
      "batch 024 / 025 | 0.04447\n",
      "batch 025 / 025 | 0.04611\n",
      "012 | 0.03647\n",
      "batch 001 / 025 | 0.05069\n",
      "batch 002 / 025 | 0.05915\n",
      "batch 003 / 025 | 0.04944\n",
      "batch 004 / 025 | 0.04977\n",
      "batch 005 / 025 | 0.04787\n",
      "batch 006 / 025 | 0.04657\n",
      "batch 007 / 025 | 0.04489\n",
      "batch 008 / 025 | 0.04125\n",
      "batch 009 / 025 | 0.04142\n",
      "batch 010 / 025 | 0.04092\n",
      "batch 011 / 025 | 0.04073\n",
      "batch 012 / 025 | 0.03948\n",
      "batch 013 / 025 | 0.04010\n",
      "batch 014 / 025 | 0.03971\n",
      "batch 015 / 025 | 0.04149\n",
      "batch 016 / 025 | 0.04251\n",
      "batch 017 / 025 | 0.04406\n",
      "batch 018 / 025 | 0.04364\n",
      "batch 019 / 025 | 0.04522\n",
      "batch 020 / 025 | 0.04580\n",
      "batch 021 / 025 | 0.04478\n",
      "batch 022 / 025 | 0.04329\n",
      "batch 023 / 025 | 0.04380\n",
      "batch 024 / 025 | 0.04396\n",
      "batch 025 / 025 | 0.04394\n",
      "013 | 0.03699\n",
      "batch 001 / 025 | 0.04583\n",
      "batch 002 / 025 | 0.03920\n",
      "batch 003 / 025 | 0.04782\n",
      "batch 004 / 025 | 0.04808\n",
      "batch 005 / 025 | 0.05121\n",
      "batch 006 / 025 | 0.04888\n",
      "batch 007 / 025 | 0.04393\n",
      "batch 008 / 025 | 0.03954\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch 009 / 025 | 0.03986\n",
      "batch 010 / 025 | 0.04068\n",
      "batch 011 / 025 | 0.04515\n",
      "batch 012 / 025 | 0.04370\n",
      "batch 013 / 025 | 0.04338\n",
      "batch 014 / 025 | 0.04316\n",
      "batch 015 / 025 | 0.04399\n",
      "batch 016 / 025 | 0.04235\n",
      "batch 017 / 025 | 0.04118\n",
      "batch 018 / 025 | 0.04234\n",
      "batch 019 / 025 | 0.04221\n",
      "batch 020 / 025 | 0.04215\n",
      "batch 021 / 025 | 0.04222\n",
      "batch 022 / 025 | 0.04340\n",
      "batch 023 / 025 | 0.04402\n",
      "batch 024 / 025 | 0.04320\n",
      "batch 025 / 025 | 0.04297\n",
      "014 | 0.03391\n",
      "batch 001 / 025 | 0.03546\n",
      "batch 002 / 025 | 0.03748\n",
      "batch 003 / 025 | 0.03432\n",
      "batch 004 / 025 | 0.02974\n",
      "batch 005 / 025 | 0.03182\n",
      "batch 006 / 025 | 0.03268\n",
      "batch 007 / 025 | 0.03313\n",
      "batch 008 / 025 | 0.03276\n",
      "batch 009 / 025 | 0.03451\n",
      "batch 010 / 025 | 0.03531\n",
      "batch 011 / 025 | 0.03579\n",
      "batch 012 / 025 | 0.03612\n",
      "batch 013 / 025 | 0.03635\n",
      "batch 014 / 025 | 0.03667\n",
      "batch 015 / 025 | 0.03663\n",
      "batch 016 / 025 | 0.03485\n",
      "batch 017 / 025 | 0.03848\n",
      "batch 018 / 025 | 0.03850\n",
      "batch 019 / 025 | 0.04014\n",
      "batch 020 / 025 | 0.04043\n",
      "batch 021 / 025 | 0.04063\n",
      "batch 022 / 025 | 0.03984\n",
      "batch 023 / 025 | 0.03987\n",
      "batch 024 / 025 | 0.03940\n",
      "batch 025 / 025 | 0.04107\n",
      "015 | 0.03359\n",
      "batch 001 / 025 | 0.01735\n",
      "batch 002 / 025 | 0.03175\n",
      "batch 003 / 025 | 0.03028\n",
      "batch 004 / 025 | 0.03346\n",
      "batch 005 / 025 | 0.03724\n",
      "batch 006 / 025 | 0.04019\n",
      "batch 007 / 025 | 0.04321\n",
      "batch 008 / 025 | 0.04033\n",
      "batch 009 / 025 | 0.03845\n",
      "batch 010 / 025 | 0.03828\n",
      "batch 011 / 025 | 0.03777\n",
      "batch 012 / 025 | 0.03596\n",
      "batch 013 / 025 | 0.03756\n",
      "batch 014 / 025 | 0.03729\n",
      "batch 015 / 025 | 0.03816\n",
      "batch 016 / 025 | 0.03785\n",
      "batch 017 / 025 | 0.03880\n",
      "batch 018 / 025 | 0.03956\n",
      "batch 019 / 025 | 0.03936\n",
      "batch 020 / 025 | 0.03866\n",
      "batch 021 / 025 | 0.03792\n",
      "batch 022 / 025 | 0.03984\n",
      "batch 023 / 025 | 0.03988\n",
      "batch 024 / 025 | 0.04014\n",
      "batch 025 / 025 | 0.04033\n",
      "016 | 0.03269\n",
      "batch 001 / 025 | 0.02684\n",
      "batch 002 / 025 | 0.02960\n",
      "batch 003 / 025 | 0.03250\n",
      "batch 004 / 025 | 0.04086\n",
      "batch 005 / 025 | 0.04064\n",
      "batch 006 / 025 | 0.03940\n",
      "batch 007 / 025 | 0.03829\n",
      "batch 008 / 025 | 0.03741\n",
      "batch 009 / 025 | 0.03854\n",
      "batch 010 / 025 | 0.03942\n",
      "batch 011 / 025 | 0.04103\n",
      "batch 012 / 025 | 0.04009\n",
      "batch 013 / 025 | 0.04021\n",
      "batch 014 / 025 | 0.03941\n",
      "batch 015 / 025 | 0.03988\n",
      "batch 016 / 025 | 0.04228\n",
      "batch 017 / 025 | 0.04052\n",
      "batch 018 / 025 | 0.04263\n",
      "batch 019 / 025 | 0.04152\n",
      "batch 020 / 025 | 0.04146\n",
      "batch 021 / 025 | 0.04027\n",
      "batch 022 / 025 | 0.04183\n",
      "batch 023 / 025 | 0.04104\n",
      "batch 024 / 025 | 0.04013\n",
      "batch 025 / 025 | 0.04012\n",
      "017 | 0.03441\n",
      "batch 001 / 025 | 0.02279\n",
      "batch 002 / 025 | 0.02426\n",
      "batch 003 / 025 | 0.02588\n",
      "batch 004 / 025 | 0.03290\n",
      "batch 005 / 025 | 0.02912\n",
      "batch 006 / 025 | 0.03279\n",
      "batch 007 / 025 | 0.03012\n",
      "batch 008 / 025 | 0.03018\n",
      "batch 009 / 025 | 0.03364\n",
      "batch 010 / 025 | 0.03534\n",
      "batch 011 / 025 | 0.03616\n",
      "batch 012 / 025 | 0.03511\n",
      "batch 013 / 025 | 0.03539\n",
      "batch 014 / 025 | 0.03499\n",
      "batch 015 / 025 | 0.03595\n",
      "batch 016 / 025 | 0.03694\n",
      "batch 017 / 025 | 0.03729\n",
      "batch 018 / 025 | 0.03735\n",
      "batch 019 / 025 | 0.03681\n",
      "batch 020 / 025 | 0.04052\n",
      "batch 021 / 025 | 0.03880\n",
      "batch 022 / 025 | 0.03809\n",
      "batch 023 / 025 | 0.03824\n",
      "batch 024 / 025 | 0.03931\n",
      "batch 025 / 025 | 0.03901\n",
      "018 | 0.03100\n",
      "batch 001 / 025 | 0.07700\n",
      "batch 002 / 025 | 0.06152\n",
      "batch 003 / 025 | 0.05033\n",
      "batch 004 / 025 | 0.04103\n",
      "batch 005 / 025 | 0.04178\n",
      "batch 006 / 025 | 0.03936\n",
      "batch 007 / 025 | 0.03917\n",
      "batch 008 / 025 | 0.04117\n",
      "batch 009 / 025 | 0.04004\n",
      "batch 010 / 025 | 0.03987\n",
      "batch 011 / 025 | 0.04580\n",
      "batch 012 / 025 | 0.04589\n",
      "batch 013 / 025 | 0.04529\n",
      "batch 014 / 025 | 0.04454\n",
      "batch 015 / 025 | 0.04522\n",
      "batch 016 / 025 | 0.04341\n",
      "batch 017 / 025 | 0.04234\n",
      "batch 018 / 025 | 0.04260\n",
      "batch 019 / 025 | 0.04204\n",
      "batch 020 / 025 | 0.04206\n",
      "batch 021 / 025 | 0.04082\n",
      "batch 022 / 025 | 0.04002\n",
      "batch 023 / 025 | 0.03949\n",
      "batch 024 / 025 | 0.04042\n",
      "batch 025 / 025 | 0.03952\n",
      "019 | 0.03096\n",
      "batch 001 / 025 | 0.01809\n",
      "batch 002 / 025 | 0.02567\n",
      "batch 003 / 025 | 0.03457\n",
      "batch 004 / 025 | 0.03546\n",
      "batch 005 / 025 | 0.03506\n",
      "batch 006 / 025 | 0.03582\n",
      "batch 007 / 025 | 0.03982\n",
      "batch 008 / 025 | 0.04147\n",
      "batch 009 / 025 | 0.03908\n",
      "batch 010 / 025 | 0.04078\n",
      "batch 011 / 025 | 0.03942\n",
      "batch 012 / 025 | 0.03711\n",
      "batch 013 / 025 | 0.03565\n",
      "batch 014 / 025 | 0.03512\n",
      "batch 015 / 025 | 0.03771\n",
      "batch 016 / 025 | 0.03719\n",
      "batch 017 / 025 | 0.03885\n",
      "batch 018 / 025 | 0.03889\n",
      "batch 019 / 025 | 0.03880\n",
      "batch 020 / 025 | 0.04062\n",
      "batch 021 / 025 | 0.04026\n",
      "batch 022 / 025 | 0.03892\n",
      "batch 023 / 025 | 0.03788\n",
      "batch 024 / 025 | 0.03703\n",
      "batch 025 / 025 | 0.03765\n",
      "020 | 0.03117\n",
      "batch 001 / 025 | 0.04205\n",
      "batch 002 / 025 | 0.05753\n",
      "batch 003 / 025 | 0.05754\n",
      "batch 004 / 025 | 0.04898\n",
      "batch 005 / 025 | 0.04493\n",
      "batch 006 / 025 | 0.04862\n",
      "batch 007 / 025 | 0.04620\n",
      "batch 008 / 025 | 0.04389\n",
      "batch 009 / 025 | 0.04271\n",
      "batch 010 / 025 | 0.04287\n",
      "batch 011 / 025 | 0.04083\n",
      "batch 012 / 025 | 0.03941\n",
      "batch 013 / 025 | 0.03876\n",
      "batch 014 / 025 | 0.03954\n",
      "batch 015 / 025 | 0.03865\n",
      "batch 016 / 025 | 0.03715\n",
      "batch 017 / 025 | 0.03654\n",
      "batch 018 / 025 | 0.03655\n",
      "batch 019 / 025 | 0.03816\n",
      "batch 020 / 025 | 0.03823\n",
      "batch 021 / 025 | 0.03822\n",
      "batch 022 / 025 | 0.03843\n",
      "batch 023 / 025 | 0.03809\n",
      "batch 024 / 025 | 0.03716\n",
      "batch 025 / 025 | 0.03699\n",
      "021 | 0.03079\n",
      "batch 001 / 025 | 0.00557\n",
      "batch 002 / 025 | 0.01641\n",
      "batch 003 / 025 | 0.01215\n",
      "batch 004 / 025 | 0.01772\n",
      "batch 005 / 025 | 0.02054\n",
      "batch 006 / 025 | 0.01961\n",
      "batch 007 / 025 | 0.02490\n",
      "batch 008 / 025 | 0.02573\n",
      "batch 009 / 025 | 0.02575\n",
      "batch 010 / 025 | 0.02683\n",
      "batch 011 / 025 | 0.02951\n",
      "batch 012 / 025 | 0.03016\n",
      "batch 013 / 025 | 0.03187\n",
      "batch 014 / 025 | 0.03359\n",
      "batch 015 / 025 | 0.03288\n",
      "batch 016 / 025 | 0.03313\n",
      "batch 017 / 025 | 0.03478\n",
      "batch 018 / 025 | 0.03355\n",
      "batch 019 / 025 | 0.03272\n",
      "batch 020 / 025 | 0.03421\n",
      "batch 021 / 025 | 0.03319\n",
      "batch 022 / 025 | 0.03463\n",
      "batch 023 / 025 | 0.03573\n",
      "batch 024 / 025 | 0.03684\n",
      "batch 025 / 025 | 0.03664\n",
      "022 | 0.03102\n",
      "batch 001 / 025 | 0.01782\n",
      "batch 002 / 025 | 0.01890\n",
      "batch 003 / 025 | 0.02950\n",
      "batch 004 / 025 | 0.03226\n",
      "batch 005 / 025 | 0.02933\n",
      "batch 006 / 025 | 0.03103\n",
      "batch 007 / 025 | 0.03151\n",
      "batch 008 / 025 | 0.03032\n",
      "batch 009 / 025 | 0.03146\n",
      "batch 010 / 025 | 0.03049\n",
      "batch 011 / 025 | 0.03042\n",
      "batch 012 / 025 | 0.03255\n",
      "batch 013 / 025 | 0.03496\n",
      "batch 014 / 025 | 0.03603\n",
      "batch 015 / 025 | 0.03800\n",
      "batch 016 / 025 | 0.03956\n",
      "batch 017 / 025 | 0.03871\n",
      "batch 018 / 025 | 0.03941\n",
      "batch 019 / 025 | 0.03835\n",
      "batch 020 / 025 | 0.03735\n",
      "batch 021 / 025 | 0.03662\n",
      "batch 022 / 025 | 0.03673\n",
      "batch 023 / 025 | 0.03661\n",
      "batch 024 / 025 | 0.03596\n",
      "batch 025 / 025 | 0.03597\n",
      "023 | 0.03037\n",
      "batch 001 / 025 | 0.02188\n",
      "batch 002 / 025 | 0.01612\n",
      "batch 003 / 025 | 0.02627\n",
      "batch 004 / 025 | 0.02667\n",
      "batch 005 / 025 | 0.02662\n",
      "batch 006 / 025 | 0.03005\n",
      "batch 007 / 025 | 0.02678\n",
      "batch 008 / 025 | 0.02795\n",
      "batch 009 / 025 | 0.02769\n",
      "batch 010 / 025 | 0.02644\n",
      "batch 011 / 025 | 0.02819\n",
      "batch 012 / 025 | 0.03146\n",
      "batch 013 / 025 | 0.03089\n",
      "batch 014 / 025 | 0.02924\n",
      "batch 015 / 025 | 0.03006\n",
      "batch 016 / 025 | 0.03003\n",
      "batch 017 / 025 | 0.03002\n",
      "batch 018 / 025 | 0.03321\n",
      "batch 019 / 025 | 0.03400\n",
      "batch 020 / 025 | 0.03324\n",
      "batch 021 / 025 | 0.03348\n",
      "batch 022 / 025 | 0.03402\n",
      "batch 023 / 025 | 0.03351\n",
      "batch 024 / 025 | 0.03560\n",
      "batch 025 / 025 | 0.03578\n",
      "024 | 0.02987\n",
      "batch 001 / 025 | 0.01686\n",
      "batch 002 / 025 | 0.01908\n",
      "batch 003 / 025 | 0.02229\n",
      "batch 004 / 025 | 0.02440\n",
      "batch 005 / 025 | 0.02331\n",
      "batch 006 / 025 | 0.02935\n",
      "batch 007 / 025 | 0.02666\n",
      "batch 008 / 025 | 0.03167\n",
      "batch 009 / 025 | 0.03199\n",
      "batch 010 / 025 | 0.03433\n",
      "batch 011 / 025 | 0.03243\n",
      "batch 012 / 025 | 0.03466\n",
      "batch 013 / 025 | 0.03377\n",
      "batch 014 / 025 | 0.03241\n",
      "batch 015 / 025 | 0.03151\n",
      "batch 016 / 025 | 0.03079\n",
      "batch 017 / 025 | 0.03018\n",
      "batch 018 / 025 | 0.03382\n",
      "batch 019 / 025 | 0.03428\n",
      "batch 020 / 025 | 0.03490\n",
      "batch 021 / 025 | 0.03607\n",
      "batch 022 / 025 | 0.03677\n",
      "batch 023 / 025 | 0.03631\n",
      "batch 024 / 025 | 0.03676\n",
      "batch 025 / 025 | 0.03620\n",
      "025 | 0.03098\n",
      "batch 001 / 025 | 0.02503\n",
      "batch 002 / 025 | 0.03589\n",
      "batch 003 / 025 | 0.03012\n",
      "batch 004 / 025 | 0.04582\n",
      "batch 005 / 025 | 0.04090\n",
      "batch 006 / 025 | 0.03892\n",
      "batch 007 / 025 | 0.03599\n",
      "batch 008 / 025 | 0.03890\n",
      "batch 009 / 025 | 0.04375\n",
      "batch 010 / 025 | 0.04098\n",
      "batch 011 / 025 | 0.03962\n",
      "batch 012 / 025 | 0.03806\n",
      "batch 013 / 025 | 0.03733\n",
      "batch 014 / 025 | 0.03619\n",
      "batch 015 / 025 | 0.03535\n",
      "batch 016 / 025 | 0.03568\n",
      "batch 017 / 025 | 0.03744\n",
      "batch 018 / 025 | 0.03667\n",
      "batch 019 / 025 | 0.03577\n",
      "batch 020 / 025 | 0.03519\n",
      "batch 021 / 025 | 0.03442\n",
      "batch 022 / 025 | 0.03333\n",
      "batch 023 / 025 | 0.03485\n",
      "batch 024 / 025 | 0.03546\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch 025 / 025 | 0.03574\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(1)\n",
    "# alpha must stay strictly positive: with alpha_i == 0 the utility constraint\n",
    "# -exp(-alpha_i * units_i) >= alpha_i * t_i becomes -1 >= 0 and the problem is infeasible.\n",
    "ALPHA_MIN = 1e-3\n",
    "alpha = torch.rand(m).clamp(min=ALPHA_MIN)\n",
    "alpha.requires_grad_(True)\n",
    "\n",
    "def loss(X, Y, alpha):\n",
    "    \"\"\"MSE between observed allocations Y and the layer's predicted allocations.\n",
    "\n",
    "    X holds the budget in column 0 and the prices in the remaining columns.\n",
    "    \"\"\"\n",
    "    Yhat = layer(X[:,0], 1. / X[:,1:], alpha)[0]\n",
    "    return loss_fn(Yhat, Y)\n",
    "\n",
    "def callback():\n",
    "    \"\"\"Project alpha onto [ALPHA_MIN, inf) in place (passed to `fit` as its callback).\"\"\"\n",
    "    alpha.data.clamp_(min=ALPHA_MIN)\n",
    "\n",
    "print(loss(X, Y, alpha_true), loss(Xval, Yval, alpha_true))\n",
    "print(loss(X, Y, alpha), loss(Xval, Yval, alpha))\n",
    "\n",
    "val_losses, train_losses = fit(lambda X, Y: loss(X, Y, alpha), [alpha], X, Y, Xval, Yval,\n",
    "                               opt=torch.optim.Adam, opt_kwargs={\"lr\":1e-2},\n",
    "                               batch_size=4, epochs=25, verbose=True, callback=callback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shane/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:9: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  if __name__ == '__main__':\n",
      "/home/shane/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:1946: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.\n",
      "  warnings.warn(\"reduction: 'mean' divides the total loss by both the batch size and the support size.\"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.0506, grad_fn=<KlDivBackward>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "theta = torch.zeros(m + 1, m)\n",
    "theta0 = torch.zeros(m, requires_grad=True)\n",
    "theta.requires_grad_(True)\n",
    "opt = torch.optim.LBFGS([theta, theta0], max_iter=500)\n",
    "# 'batchmean' sums over the m classes and averages over samples, matching the\n",
    "# definition of KL divergence ('mean' additionally divides by the number of classes).\n",
    "l = torch.nn.KLDivLoss(reduction=\"batchmean\")\n",
    "log_softmax = torch.nn.LogSoftmax(dim=1)  # normalize over assets, not over samples\n",
    "\n",
    "# Target: each row of Y as a distribution over assets. Clamp tiny negative\n",
    "# solver noise first so the target is a valid probability distribution.\n",
    "Y_target = Y.clamp(min=0)\n",
    "Y_target = Y_target / Y_target.sum(1, keepdim=True)\n",
    "\n",
    "def closure():\n",
    "    \"\"\"Re-evaluate the KL loss and its gradients, as required by LBFGS.step.\"\"\"\n",
    "    opt.zero_grad()\n",
    "    kl = l(log_softmax(X @ theta + theta0), Y_target)\n",
    "    kl.backward()\n",
    "    return kl\n",
    "\n",
    "opt.step(closure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shane/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  after removing the cwd from sys.path.\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAB+CAYAAAA+/4aWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2de1xUdfrHP18GRO7DReXiBdHVlN1NEdxXtpqbgxpqmeAtNckSFH+VrYbharZaKa5mXlZkNNdLGyF4rTRktFYzb0C62nof10JJFJxEBWGY5/fHXHbAuTN3v+/X67xm5pzv5Tlnzvc85/lenocRETgcDofDaY6HowXgcDgcjnPCFQSHw+FwdMIVBIfD4XB0whUEh8PhcHTCFQSHw+FwdOLpaAE4HHeBMSYEEA8gjoiWNtufBkAKQEpEZQ4SkcMxC25BcDhWgohkAEp0HEoDICaiQgBj7SsVh2M5TmdBhIWFUXR0tKPF4DwmlJaW3iaiNjauJkHLoogxlpi3AY49MdQGnE5BREdHo6RE10sYh2N9GGPX7FylUI8caVBaGujYsSNvAxy7YagN8C4mDsf2nGSMqS0Hqa4ERCQmongiim/TxtYGDYdjGi6jIHJzczF+/HhHi8HhGGMMgETGWIxqywQgBpDCGEsBkNuSwsvLy1FfX28NOTkcozhdF5M+pFIpduzYASICY8zR4jx2NDQ0oLy8HHV1dY4WxWJat26N9u3bw8vLy2Z1EJEYSoWgZmmzT4u5du0aYmNjMW/ePLzzzjstLY6jB3e41/VhbhtwGQURHh6O+vp6yGQyBAcHO1qcx47y8nIEBAQgOjraJRU0EaGqqgrl5eXo3Lmzo8WxiE6dOmHw4MFYuHAhxo8fj06dOjlaJLfE1e91fVjSBlymiyk8PBwA8MsvvzhYkseTuro6hIaGumyDYYwhNDTU5d8KP/74YzDG8MYbbzhaFLfFXvd6WVkZyspstyRGJpNBIpFoflvSBlxOQVRUVDhYkscXV1UOalxdfkA5w+m9997Dnj17sGfPHkeL47bY416JiYlBfn6+zcoXCoUoKChoss/c83I5BcEtCPemrKwMXbp0afLmo95XWFgIiUQCsVhsoAT3Z+bMmfjtb3+LkydPOloUk7l06RL69+8PqVTnJK7HEqFQ54xnp8KlxiAAriDcnbi4OMTExEAkEj2yLyUlBQCQmJiIMWPGuEQDswVeXl44fvw4fH19HS2KyaxatQrfffcd0tPTUVxc7GhxnAqpVKp5IRKJRJDJZCgpKYFMJkNaWhokEonmmiUmJiI/Px/p6ekoKytDZmbmI/kBQCKRICYmBtXV1S2SzWUUhFAohLe3N1cQTsLAgQONphk+fDhmz56tSZ+amorU1FTcvn0bYWFhZtcpkUhQUFCAxMTEx1Y5qFErh5KSEgQGBqJbt24Olkg/CoUCu3btAgC88sorLjMTcebMmTh16lSLyujVqxc+/vhjg2nmzJmDrKwsAEplIRKJIJVKNWMIIpEIJ0+eRHZ2NgCguLgYIpFI033UPH9BQQGys7N1djGZi8t0MTHGEB4ezhXEY4ZMJtN8F4lEmkbCAR48eIChQ4di2rRpcObQwd9//z3Ky8vx6aef4qWXXnIJ5WBvYmJiEBcXh/j4eMyZM0fzW01oaKjmu66XI+381sRlLAgAXEE4Ed9++63F6c2xHiQSiaZrCVA2jvz8fGRmZppVvzvi6+uL999/H9OnT0deXh5eeuklR4ukk6ioKEyePBnz5s1Dz549IZFIIBQKMXXqVEeLZhBjb/4tRSKRQCqVIjs7G2KxWPOA79KlC6RSKaqrq1FaWop79+5pxpskEgnKysoglUpRUlKiM/+cOXOwbds2xMTEaNLExBh1AaYbInKqrU+fPqSPF154gX73u9/pPc6xHf/5z3/sUk9paSnFxMRQcXEx5ebmUkpKimZfaWkpERGlpaVR
bm4u3blzx+zym58HgBJygvteezPUBpojl8spISGB2rVrZ9H1sBczZswgANS3b18SiUQkFArp1q1bjhZLJ/a61x2FOW3A5SyI77//3tFicGxIXFwcrly5ovmdlpYGAE325ea2yFuFWyEQCJCTk4O+ffti/vz5WL16taNFasKZM2dw/fp15Ofno3379jhx4gTeffddfPPNN5g/fz5ycnIcLSLHAC4zBgEoFcStW7fQ0NDgaFE4HKehT58+mDVrFiIjIx0tyiN89NFHSE5Oxu3bt7F69Wo8++yzWLlyJV555RWIxWKcPn3a0SJyDGA1C8JQ1CyVK2MpACEpg6ZYhHqqa2VlJaKiolomMIfjRixd2mJXT1anrq4OO3bsQEREBKqqqvDcc8+hR48e+P3vf69xmfPmm2/im2++4QPXToo1LQidUbNUHiylRCRpiXIA+FoIDscQRIQdO3bgs88+c7QoAIB9+/bh7t27uH79OlJSUuDt7Y3u3bsjKysLhYWFeP311/Hhhx9y5eDEWFNBJJAy5CLQNGpWIoAYxlgKY0ykI5/JREREAOAKguOcMMaEjLFM1b0e1+xYHGNM1NI2YIw1a9ZgxowZqKystGU1JpGXl4fAwEDU1dU1mWH1zjvvoGvXrvjnP//ZZConx/kwqiAYY7MZY9GMsW2MsddMLLf5RN0SlfUwR08daYyxEsZYya1bt/QWyi0IjpOjz4oWAQARSWBCyFFLYYzh73//O+7fv4+3337bVtWYRGNjI0pKShASEoKoqCgMGDBAc6x169ZYu3YtLl26hMWLFyMjIwMffPCBA6V1HM7uesQUC+IHAOlQPtwNnY2+qFlXdCXWhkyMptWuXTsAXEG4O2VlZRCLxZBIJJBIJJr+dYlE0sQfk3oRXVlZGfr06dOksUkkEiQmJjZZaGcHdFrRKsWwnjGWC2CbLQXo0aMHZs+ejS1btuBf//qXLasyiEAgwLFjx1BeXo5x48ZBIBA0OZ6YmIjx48djyZIl+O9//4sPPvgAP/30k83kkcvlTje5RSqVorCwRb3uNsfULqZqAATAkD3YJGpWs2haItX+Fi2Dbd26NYRCIVcQboxMJkNubi7S0tIgEokgEolQVVWlcSGQkpICkUiEMWPGYM4cpUEaFxeH9PT0JtNfZTIZ4uLiHOmSQ1OxqrtpKpQvS1m6EptqRZvCvHnz0KlTJ2RkZDjsoUhE2L17N+Ryud4FfB999BF8fHxw9+5dEJHNFj/K5XK0b98eM2fOtEn5llJWVoaTJ0+irKwMhYWFSE9PR3p6uualSCaTYfTo0QCUykQsFkMsFtvV6jBlFlNvAIVQRsTary+R6s2p+VQKq0XTUhMeHs5dfjsYW/qo2bZtG/r06dNkn7rRJCYmavYJhUKUlJRofsfExKC0tBQANCtHHeDt9CRjLIaIpGhqRYuIaCmAMtVL0yOQViS6+Pj4FvnN8PX1RW5uLm7fvg1PT/svdaqqqkLfvn3h5+eH7t27o3fv3jrThYeHa7qYRo0ahfz8fGRkZDTpjrIGnp6e8PT0xNq1a7FmzRqzBsV1+RwbM2YMMjIy8ODBAyQlJT1y3FSfY2qfS2pnlGp/SzKZDMXFxRAKhQgJCQHwqL8li1dGm4mpXUxpMN7FZBe4u43Hj5iYGMhkMqOeKRMTEyGRSCCTyezWgJqhz4qWaA1c2y5CjBZDhgzBhAkTHDJDaPv27ZBKpThz5oxR30tpaWno27cvDh8+jKioKI1VaA127tyJNWvWAAD++te/AgD+85//WK18a6G2CNT+lrStXu173lb+lgxh6utFFf7XxXTQduIYJzw83KX84LsjtvRRo911pEbtj0nbUZ8uJZCSkqIx0x2BESvaLoqhOWvXrsWxY8ewZcsWu9WZl5eHNm3a4NatWxg/frzBtAKBAOvWrUN8fDyef/55fPTRR1aRYf369Zg2bRr+8Ic/YNq0aejfvz8A4MCBA4iNjTW5HEM+x3x9fQ0eN+ZzTCgU
arpP1d1N2qh9Nenyt2Q39PngUG8ABkE5sFYE4Flj6Vu6GfND89Zbb5Gfn58xdyMcK2NP/zSlpaWUm5tLxcXFVFBQoNmv9s9UXFxM2dnZGt9DpaWllJKSQnfu3KHMzEwiIsrNzSWRSPSIfyJ388VkjMWLFxMA+vLLL61WpiHKy8uJMUYRERGUkJBgcr6ZM2cSY4yOHTtGCoWC6uvrLapfoVDQ+++/TwBo6NChdO/ePTp48CCFhYVRQEAAPf/880bL4L6YtJ7/+g5oEgCjdH231WascWRnZxMAqqmpseTacCzEXRrN46YgHj58SD169KDOnTvT/fv3rVauPlasWEFQ9jbQihUrTM539+5dioqKoieffJKSkpJo5syZFtU/c+ZMAkATJkyghw8f0ooVK0ggEBBjjPz9/en48eNGy3CXe10f5rQBU8YgQvV8dwh8LQSHYzqtWrXC2rVrcfXqVSxevNjm9SUkJODpp5+Gh4cHxo4dazyDioCAAKxatQqnT59GdXU11qxZg3Pnzpldf3R0NGbOnInc3Fy89tpreOuttzBixAiIxWLcu3cPcrnc7DIfa/RpDvUGoDOUXUz5AKKNpW/pZuztqaioiADQ4cOHLVOfHItwl7eqx82CUDNx4kTy9vamX375xepla6NQKKhLly4kEoksyjts2DDy9fWlwMBAGjx4MCkUCqP5ampq6MSJE5rf165do7i4OAJAf/3rX6mxsZGqq6tJIBDQ6NGjqaioyGB57nKv66PF7r4ZY72hXAlKABj+N3spG1orRB2B2oLgU105HNP529/+hilTpiAoKMhmdXz77beoqKjAlStX8Je//MXs/IwxrF69GrGxsYiOjsb+/fvx5ZdfYsSIEXrz3L59G8OGDcPFixdx9epVnD59GqNHj0ZdXR12796N559/HgAQHByMfv364YsvvkBVVRUGDx5sUJba2lq0bt3arfxEERHq6urMymNoFtNiIvpVewdjbJAlglkT3sXE4ZiHXC7H7t27sWDBAkRFRaGoqMiimODG+L//+z/cuXMH3t7eGDVqlEVldO7cGe+++y6ysrLQoUMHZGdn61UQP/30E4YMGYKrV68iLy8PW7duxVtvvYWuXbti165deOKJJ5qkT0pKwuHDh3HkyBHU1tbCx8dHZ7kRERG4fv260628tgZeXl4an3Ymoc+0cNRmzLyWy+UkEAho7ty55lpWnBZgT7P7ypUrNiv7cepiUigUtHPnTurevTsBoNjYWAJAfn5+VFFRYZU61Pz73/8mABQQEECjRo1qUVn19fUUGxtLERERdOPGDZ1pfvzxR2rfvj0FBgZScXExTZ48mQDQiBEjSCaT6cxz+vRpzQD6gQMHWiSjO2GoDbhUwCBAOW+6bdu23IJwU1zBP40rcPToUfTv3x8vvvgiPDw8sHv3bpw5cwZPPfUU7t+/r4nUZy3y8vLg4eGBmpqaFsfG9vLyQk5ODioqKrBy5UrI5XLcvXu3SZqcnBzI5XIUFBRg7ty52Lx5MxYsWIBdu3bp7Ub73e9+h4iICDDGcODAgRbJ+NigT3M4ajPl7SkuLo6SkpIs1pgc87GXBVFQUKCJQ11QUEBpaWmUlpbWZO1DSkoKESktjdzcXMrNzTXZ6nB3C+LChQuUnJxMACg8PJzEYjE1NDRojp89e5YYYwSAdu7caXE92igUCurcuTNFRUVRYGAgPXjwwCrlTpkyhQQCAfXo0YMmT55MRKRZH1FfX0+FhYXUtm1b8vf3N/lcpk6dSgKBgIYOHWoVGd0BQ23ALAuCMRZtfRVlPtzdhuMZOHDgI9vatWsBAA8ePNB5fNOmTQCUA4v6EIlESEhIQFxcHEQiEYRCIXJzcxEfH4+qqqpH/NPEx8cjPj7e6d0m25qbN29ixowZiI2NRVFRERYuXIjLly9j6tSpTfwxxcbGYsaMGQCAqVOnWsXb7ZUrV1BRUYGqqiqMGjVKb9++uWRnZ0MoFOLu3bvYvHkzsrKy8OST
T+KXX37B+vXrMW7cOAQFBeH48eMYOXKkSWUmJSWhsbHRqi493BlT4kG8zRhbxxhbB6DADjIZhSsI98eZ/dM4E/fv38eiRYvQtWtXjRfcy5cvY/78+fDz89OZZ9GiRRAKhaiursbRo0dbLEPXrl2xYcOGRwIDtZSwsDD87W9/w/Xr1xEUFIQlS5agbdu2yMzMxIwZMzBkyBCcOHECPXv2NLnMQYMGwcvLC/v27bOanG6NPtNCvUHLvQaAzsbSt3QzxbyeO3cuCQQCamxsbIFhxTEHew5SZ2ZmatxsqLuTtPfHxcXRlStX6MqVK5SdnU3FxcWPuNTQh7t0MTU0NJBYLKaIiAgCQKNGjaLz58+bdA2IiNavX08A6LPPPjM5jyGSk5OpXbt2TbqzrIFCoaABAwZQQEAATZ48mRISEggAzZs3z+L2P2jQIAoJCaH333/fqrK6KobagCkKYhuAHABLABQZS9/SzZTGsWrVKgJAN2/ebNGF4ZiOuywecnUFoVAoaPfu3dSjRw8CQP369aMjR46YfR3kcjn16dOHIiMj6fPPP7fYDcfBgwepV69e5OXlRW+++aZFZRjjxx9/JE9PT2KMkZ+fH23fvr1F5S1fvpwAULdu3awkoWtjqA2YMgaRT0TTiegdGIjrYCger+p4GmPMKtFb+FoIjjNiJCZ1miomdYql5ZeUlGDgwIF44YUX0NjYiJ07d+K7775Dv379zC5LIBBg1apVuHHjBsaNG4f33nvPIpk+++wznD9/Hg0NDVbtXtKmZ8+e+OCDD/Dkk0/i2LFjFq+xUKOO4XDx4kX+DDGCUQVBRNu1vhuaG6YzHi+gbDgAEgGEWChnE7iC4Dgp+mJSpwCQEpFEdcwizp49i/PnzyMnJwdnz57FyJEjW7TSt1+/fpg0aRI8PDywfPlylJWZ55G8vr4e27dvh1AoRJcuXZCQkGCxLMbIzMzEDz/8gN/+9rctLqt79+6IjIwEABw86NDoBU6PKYPUgxhj2xhjRYyxZw0k1RmPV0U8AKsFcVCvBOQKguMIGGOBeg7pawOJAGJUloVIT5lGQ45OmjQJly9fxrRp0+Dl5WWx/NosWbIEPj4+8PLywquvvmrW6uGioiLcuXMHN2/eNBoYyJlgjGlccOzfrzdIJgemRZQLIqIxRDQEWnF2jdA8Hm+JgbRmx+PlFoRjqK2tVY9LuRxEhNraWovyMsY6qz7fZoy9DaWlYIzmbaVEZT3onF9JRGIiiiei+DZt2ugsUCAQICAgwAzJjRMZGYl3330XDx8+xKlTp7B8+XKT8+bl5cHX1xdEZDQwkLMxfPhwAHAJ765Hjx5FcXGxQ+o2JaKcqe6+9cXjjYGyaykBgAyquLvakJnxeP39/eHn58cVhB1xB/80ZvuhUUFEV1VfJUT0A2NMn8c7fW3gitmV2pE333wTGzZsQGVlJTp27GhyvmHDhuHIkSMIDQ1Fjx49bCih9fnTn/4Eb29v6FPGzsJ///tfDBkyBHK5HJcuXUJUVJRd6zdFQUgYY9ug9GFiaHWJGEAaY0wKVTxeAClEtFQ1BmHVlSl8LYR9EQqFTdYiPA4wxpKh7B7qDKAYqpcYaubEUgudbUBrfwyUHpGdCm9vb3z88ccYNmyYWV6S//CHP+Cnn37C66+/bkPpbIOvry/+9Kc/4auvvsKCBQuc9t5esWIFAKCxsRELFizAhg0b7CuAvulNAF5Tfb4N5RTXJVDOaHL4NFcioqeffpoGDhxoUloORx8wNAccSFZ/qrZ1cIKwu7Zi2LBh5O/vTwsWLKCNGzcaTPvFF19QZmYmMcbo559/tpOE1kU9XX7YsGGOFkUvDQ0NdOrUKXrrrbfIw8ODpFKp1esw2Ab0HgB6qz4Hae0bpC+9tTZTG0dycjI98cQTll4TDoeIjCqI3qoXo9la+5L1pbfW5igFcfHiRfL09KSIiAgKDAyk8vJynenu3LlDrVq1ouDgYJd+Sbt06RIBIH9/f6dbdHvx4kWq
rKzU/L59+zZJJBKb1GWoDegdpCaiH1RfqwCNHyanGaHkXUwcW0NEP5By/c9VxtgSxthiAHccLZet+M1vfoNZs2ahoqICDx8+REZGhlopNmHnzp2or6/HnTt3bLb2wR507doV7dq1w71793D27FlHi6Ohrq4OycnJGDx4sOb6h4aGYtAgZTiexsZGu8miV0EwxjqrZmxkMcZmAxgNZX+sUxAREQGZTGZ2hCQOx1yIaDsRvUNEWUTk1hPn//KXvyAiIgJt2rTBnj17UFDwqPu1vLw8CIVCeHp6Ijk52QFSWg/1bCZn8s00f/58nDlzBh9++CEYYyAiKBQKAMDixYvxzDPPaH7bGkMWxFUAhVC62lgP5UCb7aOem4h6quvNmzcdLAmH4z4EBARg6dKlKC8vR6dOnfDmm282mR588+ZNHDhwAI2NjUhKStJ41nVVxo5Vrmfcvn27kZT24dtvv8Xy5csxbdo0PPfcc3jw4AESExMxYMAANDQ0IDIyEkeOHNGpuG2BwXUQKiUhATAIgAhONAODr4XgcGzDhAkT8NRTT6GmpgYbN25s4r770KFDUCgUVgkM5AwMGDAArVq10jxPHMmvv/6KyZMno0uXLli2bBnq6+uRnJyMgwcP4siRI1i6dCkmTpyI3//+98jKysLDhw9tLpMpC+XSoFzLIANg3lp8G8IVBIdjGxhjWL16Ne7cuQOJRAJAGeMDAEaPHo2XXnoJfn5+emNFuxLe3t4YOnQozpw5o3O8xZ7I5XLEx8fj008/RevWrTFx4kR8/fXXWL9+PcaOHYuFCxfi/PnzWLp0Ka5evYp169bZXih9o9fqDcBUKKf49QIwylj6lm6mzuAoLy8nAJSTk2POgD2H0wS4mDdXezJ16lTy9PSk2bNnU0xMDNXU1FBdXR0FBwfTxIkTHS2e1Vi3bh0BoK1btzpaFCJSeux99dVXCQAtX76ciIgqKyspLCyMEhISqL6+nkQiEbVr144ePnzY4voMtQFTFIR6uusSqNZG2HIztXHU19cTAFqwYIGFl4XD4QrCEJWVlRQUFKSJwRATE0MqTwe0d+9eR4tnNa5du0YAqEuXLg6p/8aNGzRixAiSSqWkUCjoz3/+MwGg+fPnN0n3+eefEwDKzs6m8+fPmxX/wxAtUhBNEgPR5qS3ZDOncYSFhVF6err5V4TDUcEVhGFWrlxJAGjo0KEEgEJCQigsLEwTG9pdCAsLI8aYxXExLEWhUNDQoUPJx8eHzp8/T4sWLSIA9Prrr5NCoXgk7Ysvvkje3t507tw5zf6WBmky1AZ0jkGoPLiuY4zlMMbyVZ9OE3JUTUREBB+D4HBsyPTp0xEbG4tz586hT58+qKmpwdixY63mTdZZGDhwIIjI7k7x1q1bh6+//hrLli3D/v37MX/+fLz88sv4+OOPH/GOyxjD2rVr4efnhylTpkAul2PcuHFISzPFd6Rl6BukLiGiaUQ0HUr/9tOJaBqAd2wmiQXwxXIcjm3x8vLCypUrce3aNXh4eNg0MJAjefXVVwEAmzdvtludFy9exKxZszBkyBD4+fnhjTfewMiRI/HJJ5/Aw0P3ozk8PBwrV67E0aNHsXr1arRv3x6bNm3CmTNnbCOkPtNCvUE5SB2o2pxmDIKIaNKkSdSpUyczDSoO53+AdzGZRHJyMgGg6OjoR7o+3IH6+noSCATUrl07u9U5btw4Cg4Opg0bNpCHhwcNGjSIamtrjeZTKBQ0fPhw8vHxoZMnT5JQKKSkpCSL5TDUBkyZ5iqBMtRotuq706C2IJTnyOFwbMWyZcvg5+eHl19+2WUCA5mDl5cXBg8eDIFAYLfnyYYNG7Bo0SJkZGQgISEBu3btQuvWrY3mY4xh3bp1aNWqFWbNmoW5c+di7969NomOZ8jVxmuqrylQroH4FU60UA5QKoiHDx/i11/1eV/mcOyHPeOy25vo6GhIpVLMmzfP0aLYjJSUFNy4ccN23TUqLl++jAcPHuDf//435syZ
g+7du2Pv3r3w9/c3uYyoqCisWLEChw4dgqenJzp27IgPP/zQ6rIaigehjgJXSir/M4yxQVaXoAWoF8tVVFQ4rT93zmOFOia1jDGWDa2FpVpx2SVQvnC5HG3btnW0CDZl6NChAICsrCx89dVXNqnj/v37SEpKQlhYGM6dO4fw8HDs37/fIpclqampyM/Px/z58/Hpp5+if//+VpdX3yym3gDGqbxXDmaMLWaMLYFpoRbtBl9NzXEyLI7Lbm7YXY71iYyMREBAAL755hub1ZGZmYnLly/j/Pnz8PPzg0QisdjNB2MM69evh4eHB1avXo2QkBAoFAqrhlE1NAaxmJTeK9VeLN8B8JkBYXWa16r9ItUxnQHbLUUdPpIrCI4TYlZcdjIhJjXH9vTt2xe1tbX48ccfrV72vn37sHbtWvj7+0MgEKC4uBjR0dEtKrNDhw5YtmwZDh48iFWrViE+Ph6rV6+2jsDQ08VE/4sFAcZYIJSO+gClibxbT1n6zOsxUMbylTDGimFkoPvChQsYOHCgQaGHDx+O2bNnazTvnj17MH78eNy+fRspKSkG8wJK0yw1NVWTftasWRgxYgQuXLiA9PR0o/mbp//www/Rr18/fP/995g7d67R/M3T5+bmonv37vjiiy9MChrfPH1hYSHCwsKwadMmbNq0yWj+5um//fZbAMqByC+//NJofu30R48e1XjCzMrKwtGjRw3mDQ0NbZK+qqoKYrEyTHlaWhouXrxoMH+3bt2apA8NDcXixUonw8nJyaiqqjKY/6mnnmqS3spYHJed4xxMnDgRBw4cQE5ODtasWWO1cquqqpCamopWrVoBAIqKiqwWx3vq1KnYtm0b5s+fj969e2PRokVITU1FcHBwi8s2ZRZTOpQ3+K8w7KxPp3mtejOSquLxSnVl1DavGxoaTBRdGSeZMcYHqTnOghhACmMsBaqY1IyxTCIqhNKC4ANlTs6ECRPAGMPXX39t1XJv3LiB+/fvAwC++uorxMU9MofBYtRdTQqFAg0NDZDJZFiyZIl1Ctc3/1W9wURnfQAKtL4X6zieDUBorD5z54B37NiRXn75ZbPycDhqwNdBcJrRpUsX8vHxsVoY0nv37lH//v1JIBDY1IfV6tWrCQD162lF+zgAAAz+SURBVNePvL296dq1ayblM9QGTLEgLhDRdgDjoDST9XFSZSUAzSwF1RvVYiP5LYKvpuZwONbkvffeQ21tLX744QfjiY2wceNGdOzYEYcPH8Y///lPPPfcc1aQUDcZGRkYMGCAZpruqlWrWlymKQriJcZYL1IOVm8wkE6nea36nQWlHyfjHfxm0qFDB5w6dQoymUvOHORwOE7GkCFDwBjDnj17WlROXl4eXn31VVRXV2PLli2a6HW2wsPDA5988gnkcjn69OmDDz74oOWF6jMtmm8AXoNyZpNTmdcnTpwggUDgVv7pOfYDvIuJo4PAwEAKCQmxOH9BQQExxogxRoWFhVaUzDgfffSRJr6FKa47DLUBoxaEypNrDpQDbFYa+bAeCQkJmoUi27Ztc7Q4HA7HDYiJiUF1dTUqKyvNzrtjxw6MGTNGY4XYYLacQd544w089dRTmDFjBqKjozVRAS3BlC6mMlJ6c11GRE45XWju3Lno27cvpk2bhhs3bjhaHA6H4+K8+OKLAID169eble/EiRN45ZVXIBAI8NVXX2H48OG2EM8gAoEAGzduRF1dHWpqapCZmQmFQmFRWUYVBBGZd4UcgJeXF7Zu3Yq6ujpMmTJF3SXG4XA4FjFlyhQA0KzZMYUvv/wSzz77LEJDQ3Hu3DmN6w5H8MQTT2DRokV48OABTp8+jZISg+s09WKKBeESdOvWDcuWLUNRURFycnIcLQ6Hw3Fh2rdvj6CgIJw9exaNjY1G02/evBnPP/88/Pz88N1336Fr1652kNIwf/7znxEfH4+goCB07tzZojLcRkEAyuhXQ4YMwezZs3HhwgVHi8PhcFyYCRMmoKGhASdO
nDCYbu3atUhNTQUAiMViREZG2kE643h6emLTpk24d+8eFi5caFEZbqUgGGPYuHEjfHx8MGnSJJizKpvD4XC0ef/99yEQCLB37169aZYsWYIZM2aAMYa8vDy88MILdpTQOLGxsdizZ4/FrsDdSkEASo+M69atw8mTJzFx4kQcOnTIJBORw+FwtAkODkavXr1QUFDwyDEiwjvvvIOsrCwAysFsW69zsJShQ4ciICDAorxupyAAYPTo0cjMzMTu3bvxzDPPIDIyEtOmTUNxcTG3Kjgcjsn88ssvuHDhAioqKjT7GhsbMW3aNGRnZ2PQoEFYsWKFJqa1u+GWCgIAsrOzcevWLXz++ed45pln8Omnn2Lw4MEIDw/HqFGjsHDhQnzxxRf4+eef+awnDoejk8TERADA7t1KJ9b19fUYN24cxGIx5s6di+LiYsycOdORItoU5mwPx/j4eLJ0SpYhamtrUVRUhB07duD48eO4dOmSRjGEhISgV69e6NGjB37zm99otujoaHh5eVldFo7zwBgrJaJ4K5UlhNLtvRSAlIjKtPbHA4iDcl2RwZVLtmoDHPPZv38/hgwZgj/+8Y/4+uuvMWrUKOzfvx8CgQDHjx9Hnz59HC1iizHUBgyFHHUrfHx8MHLkSIwcORIAUFNTgzNnzuDUqVOabevWrbh7964mj6enJ6KjoxEREYGwsDCEhYUhNDRU8z0yMhLt27dHVFQUAgMDHXVqHOfBKjFROM7DH//4R3h4eODEiRNITEzEsWPHAABjx45F7969HSyd7XlsFERzAgIC0K9fP/Tr10+zj4hw69YtXLp0CRcvXsSlS5dw5coVVFZW4tKlSzh69Chu376tM6Sfv78/oqKiEBkZieDgYAiFQgiFQgQFBUEoFCIgIACenp4QCASazdPTE97e3ggMDGyS1s/PD4wxe14OjnVIIKKlqu9NYqIAgKGYKBznxNfXFz179sTZs2c1012TkpKwadMmeHi4bQ+9hsdWQeiCMYa2bduibdu2ePrpp3WmISLU1NSgsrISN27cwPXr13H9+nWUl5fj+vXrqKiowIULFyCTyfDrr7/i3r17ZsshEAjg7+8PX19f+Pj4NNk8PT3R2NgIhULR5LN169YICgpCUFAQAgMDNd9btWoFDw8PMMaafHp6emo2Ly8vjfKqr69HbW0t6urqUFdXh9raWjQ0NCAwMLCJ4gsODkZAQIDeRqLL8Reg9DipvTHGwBiDQqF4JD1jrIky1VauHh4emk/1d7lcjoaGBjQ0NKC+vl7zPTQ01OJZHC1AV3CgdABzdCVmjKVBFfO9Y8eONhSLYy6bN2/G3LlzceDAATz99NMoKCh4bLqenU5B6Ao5OmbMGGRkZODBgwdISkp6JE/zEKLNmT59OsaOHYuff/4ZkyZNeuS4sZCj8+bNg0gkwqlTp3QOSDUPIaruggKUD8qFCxciOjoa//rXv/D3v/+9yUNQoVAgIyMDgYGBOHToEPbt2we5XK55+N+7dw/dunUDYwzXrl1DeXm55qEKKJVaXFwcFAoFzp49i5s3b2ryO9v4kqP4xz/+gdTUVDz33HOora21ZVX6Qo42j4nyiG96lZUhBpRjELYUkmMecXFx2LdvH1auXInU1FT4+vo6WiS74XQKwt1gjCEoKAidOnVCZGQkfHx8HknTv39/dO/eHT4+Pjh79uwjx8ViMTp06ID8/HydbkQ+//xznTGp1W/l27ZtQ+vWrbFhwwbs2rULQNM3/C1btkAul+OTTz7B4cOHm7zh+/n5Ye/evaipqcHChQtx6NAhyOVyjRLy8/PD66+/DgAoKCjA5cuXm5x7SEgIpk+fDsYYtm7dimvXrmnqB4C2bdti0qRJ8PDwwJYtW3Dz5s0m59ahQweMHj0ajY2N2LhxI+7cudNE/o4dO0IkEkGhUCA/Px91dXUaBerh4YGePXvqtQZtgBhAGmNMClVMFAApUCqLLCgtiDLosSI4zgtjzK1nK+nDarOYjMzg
eGS/PvgMDo49seYsJmvB2wDHnhhqA9YcZVHP4CgEMNaE/RwOh8NxYqypIBKISN23GmPCfg2MsTTGWAljrOTWrVtWFInD4XA4lmKrMQhdMzj07tceoGOM3WKMXdOTPwzA7ZaLZzaOqJefq33q7eSAeg1SWlp62wnbgDPUz8/dNuhtA9ZUEPpmcOid2aELImqj7xhjrMQR/cWOqJefq93qjbZ3vcZwxjbgDPXzc7d/3dZUEPpmcDTZb8X6OBwOh2NDrKYgVOMMS5vtXtrsk8PhcDgugqutFRc/RvXyc3XfeluCo2V2ZP383O2M03lz5XA4HI5z4GoWhNvCGBMyxkSMsUyt35mMsRTGWJyd6oxhjJUyxtRjSBwO5zHGJVxtmLsa24r1xgAoAFACIFs1E8smqFxEl0AZMwDQ7zralnUCwCCtdSs2oXl8BCivr03/Xx11SmGn/9YaOKoNNKvf5JgWNpIhDcA2W9+fBuqWAhCqFv3as+44KH14wd7X3SUUBOzwsDSAzR+YetDpOtoOjFE5Aiyx4UOoeXyEYtj+/21eZzoc999agiPbAODgmBYqBZWoqteu/5nK0aLUQUpRBKBadd3T7F2/q3QxGV2NbUPGqFZ626Sbx0T0LTy0KkQkJSKxauHio25trVePmIikWvERbP7/6qgTcI7/1lQc2Qb0XT97Eg/gpAPqBZSKKUbV3SuyZ8UqpbSeMZYLYJs96wZcR0FoY5eHJWC/B6YeTmqNA9ilQaoelurrG2KHKnXFR7D1/5sOYI6D/9uWYrc2oAO9MS1shUqBO9p7YYmqa8kR5z4VwBUoPQLbFVdREHZ/WAIOeWCOAZCoOlcxgBSVeWvLBYbadW4DEK8yZW3aEJrFR7DL/6tdpwP+25bikDagTbP/zJ7EQGlBJACw6xu8iisOqFONiIjKVN3NVfau3CWmuTpwkFo9MBcDZf+rUw9kugqqB00WgGoo+9IXw/aD1LrqdJn/1gkGqZtcPyKy95u0EMpJBQXqEK52rlt97WX2HItQWRDqbr0Qe4+DuISC4HA4HI79cZUuJg6Hw+HYGa4gOBwOh6MTriA4HA6HoxOuIJwExliceiGMpW4u1Pm0y+JwOBxL4YPUToY6jobWKmqb5uNw3AX1gkd7z/ByZ1zF1Ybbo1qhGQfVymLVzS7D/+Z9S1THE1W/c6GcpimEcs1EnFa+EABxRLRUpThEUE5PlKnSj1Xlj+MKheMucMVgfXgXk5Ogmt8cCqUiOKm62bOhXEFaAtV8fSjnYaerjlersou082mVBSgd0YlVq0DTVZ/VqjRd7HV+HI4tUXWrZjtaDneDKwgnRXvVrEoZqF0NVKmOZ0NpbZTpyacL9cphV3FQx+GYihSOdUHilnAF4SSouphiVA7ZQqG0GOZAGc9b3c0kgtLdAKBc/h+j2tTdTqFQOhUT4X8O3eao3EqkAMhWd2WpFEk8j/vAcSNkWu5TOFaAD1JzOByXRz1JA0Chs7tNcSW4guBwOByOTngXE4fD4XB0whUEh8PhcHTCFQSHw+FwdMIVBIfD4XB0whUEh8PhcHTCFQSHw+FwdMIVBIfD4XB08v/MS0k3AKS+/QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 394.92x129.6 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Left panel: per-iteration validation loss of the learned (COM) model versus two\n",
    "# constant baselines (LR and the true parameters).\n",
    "# Right panel: learned vs. true parameters, entry by entry.\n",
    "with torch.no_grad():\n",
    "    # LR baseline: softmax scores, scaled row-wise by Xval[:, 0].\n",
    "    # dim=-1 makes the softmax axis explicit. Calling torch.nn.Softmax() without a dim\n",
    "    # is deprecated and warns; for 2-D input its implicit choice was also the last axis,\n",
    "    # so the value is unchanged (assumes Xval @ theta is 2-D -- TODO confirm).\n",
    "    lr_scores = torch.softmax(Xval @ theta + theta0, dim=-1)\n",
    "    lr_val_loss = loss_fn(lr_scores * Xval[:, 0][:, None], Yval).item()\n",
    "    true_val_loss = loss(Xval, Yval, alpha_true).item()\n",
    "\n",
    "fig, ax = plt.subplots(1, 2)\n",
    "fig.set_size_inches(5.485, 1.8)\n",
    "\n",
    "ax[0].axhline(lr_val_loss, linestyle='-.', c='black', label='LR')\n",
    "ax[0].plot(val_losses, c='black', label='COM')\n",
    "ax[0].axhline(true_val_loss, linestyle='--', c='black', label='true')\n",
    "ax[0].legend()\n",
    "ax[0].set_xlabel(\"iteration\")\n",
    "ax[0].set_ylabel(\"validation loss\")\n",
    "\n",
    "ax[1].plot(alpha.detach().numpy(), c='black', label='learned')\n",
    "ax[1].plot(alpha_true.detach().numpy(), '--', c='black', label='true')\n",
    "ax[1].set_xlabel(\"$i$\")\n",
    "ax[1].set_ylabel(\"$\\\\theta_i$\")\n",
    "ax[1].legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# NOTE(review): savefig assumes a figures/ directory already exists next to the notebook.\n",
    "plt.savefig(\"figures/resource_allocation.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
